import json
import random

# Destination files for the three partitions produced by this script.
train_path, val_path, test_path = (
    'medical_train1.json',
    'medical_val1.json',
    'medical_test1.json',
)

# Parse the entire source file into memory; the parsed value is iterated
# record by record when the partitions are built.
with open('./medical.json', encoding='utf-8') as f:
    data = json.load(f)

# Assign every record to exactly one partition using an independent random
# draw per record. The draw is an integer in 1..10 (both ends inclusive):
#   1-7  -> train       (about 70% in expectation)
#   8    -> validation  (about 10%)
#   9-10 -> test        (about 20%)
# Partition sizes are therefore approximate and vary from run to run; no
# seed is set in the visible code.
train_list = []
val_list = []
test_list = []
for img_info in data:
    r = random.randint(1, 10)
    if r >= 9:
        target = test_list
    elif r == 8:
        target = val_list
    else:
        target = train_list
    target.append(img_info)

# Persist each partition to its own file, in train / val / test order.
# json.dump is called with its defaults, so any non-ASCII characters are
# written as \uXXXX escape sequences and the output is on a single line.
for out_path, records in (
    (train_path, train_list),
    (val_path, val_list),
    (test_path, test_list),
):
    with open(out_path, 'w', encoding='utf-8') as out_file:
        json.dump(records, out_file)
